# -*- coding: utf-8 -*-
import numpy


def accuracy(out, labels):
    """Return the top-1 classification accuracy of ``out`` as a percentage.

    Args:
        out: object whose ``.data`` attribute holds the class scores; the
            predicted class is the argmax along axis 1.
        labels: object whose ``.data`` attribute holds the ground-truth class
            indices, shaped like the predictions. A trailing singleton axis
            (e.g. an ``(N, 1)`` column of labels) is also accepted.

    Returns:
        The percentage (0-100) of predictions equal to their labels. An empty
        batch yields NaN, as ``numpy.mean`` of an empty array does.
    """
    pred = numpy.argmax(out.data, axis=1)
    truth = labels.data
    # Without this, (N, 1) labels would broadcast against the (N,)
    # predictions into an (N, N) comparison and give a meaningless score.
    if (numpy.ndim(truth) == numpy.ndim(pred) + 1
            and numpy.shape(truth)[-1] == 1):
        truth = numpy.squeeze(truth, axis=-1)
    return numpy.mean(truth == pred) * 100


def correct_num(out, labels):
    """Return how many top-1 predictions in ``out`` match ``labels``.

    Args:
        out: object whose ``.data`` attribute holds the class scores; the
            predicted class is the argmax along axis 1.
        labels: object whose ``.data`` attribute holds the ground-truth class
            indices, shaped like the predictions. A trailing singleton axis
            (e.g. an ``(N, 1)`` column of labels) is also accepted.

    Returns:
        The number of correct predictions, as the integer produced by
        ``numpy.sum`` over a boolean array.
    """
    pred = numpy.argmax(out.data, axis=1)
    truth = labels.data
    # Without this, (N, 1) labels would broadcast against the (N,)
    # predictions into an (N, N) comparison and over-count matches.
    if (numpy.ndim(truth) == numpy.ndim(pred) + 1
            and numpy.shape(truth)[-1] == 1):
        truth = numpy.squeeze(truth, axis=-1)
    return numpy.sum(truth == pred)


def accuracy_numpy(out, labels):
    """Return the top-1 classification accuracy of raw arrays as a percentage.

    Args:
        out: array-like of class scores; the predicted class is the argmax
            along axis 1.
        labels: array-like of ground-truth class indices, shaped like the
            predictions. A trailing singleton axis (e.g. an ``(N, 1)``
            column of labels) is also accepted.

    Returns:
        The percentage (0-100) of predictions equal to their labels. An empty
        batch yields NaN, as ``numpy.mean`` of an empty array does.
    """
    pred = numpy.argmax(out, axis=1)
    truth = labels
    # Without this, (N, 1) labels would broadcast against the (N,)
    # predictions into an (N, N) comparison and give a meaningless score.
    if (numpy.ndim(truth) == numpy.ndim(pred) + 1
            and numpy.shape(truth)[-1] == 1):
        truth = numpy.squeeze(truth, axis=-1)
    return numpy.mean(truth == pred) * 100
